module net.BurtonRadons.dig.common.expressionEvaluator;

private import net.BurtonRadons.dig.platform.base;

/** A quick expression evaluator.  Supports the arithmetic operators
  * (+ - * / % and ** for exponentiation), bitwise operators (& | ^ ~ << >>),
  * logical operators (&& || !), equality (== !=), the conditional operator
  * (?:) and the comma operator, plus parentheses.  Numbers may be written in
  * decimal, hex (0x prefix or h suffix), binary (0b prefix) or octal (a
  * leading 0 followed by an octal digit), with optional fractional parts
  * and exponents.  Named constants and functions are registered with define;
  * predefineMath registers a set of common mathematical ones.  All values
  * are computed as float.
  */

class ExpressionEvaluator
{
    import std.math;
    import std.ctype;
    import std.string;

    /** A custom function that takes a single double and returns a real. */
    extern (C) typedef real function (double v) FuncED;
    
    /** A custom function that takes two doubles and returns a real. */
    extern (C) typedef real function (double a, double b) FuncEDD;

    /** Define a constant.  The name is stored in lower case because
      * identifiers are lower-cased before lookup during evaluation, which
      * makes names case-insensitive.
      */
    void define (char [] name, float value)
    {
        digCommonConstants [std.string.tolower (name)] = value;
    }

    /** Define a function which takes a single double argument.  The name is
      * stored in lower case, so it is matched case-insensitively.
      */
    void define (char [] name, FuncED func)
    {
        digCommonListED [std.string.tolower (name)] = func;
    }

    /** Define a function which takes two doubles as arguments.  The name is
      * stored in lower case, so it is matched case-insensitively.
      */
    void define (char [] name, FuncEDD func)
    {
        digCommonListEDD [std.string.tolower (name)] = func;
    }

    /** Define the constants e, pi, inf and nan (stored as float) and the
      * functions acos, asin, atan, atan2, cos, sin, tan, cosh, sinh, tanh,
      * exp, log, log10, pow, sqrt, ceil, floor, log1p, expm1 and hypot.
      */
    void predefineMath ()
    {
        define ("e", (double) E);
        define ("pi", (double) std.math.PI);
        define ("inf", float.infinity);
        define ("nan", float.nan);

        define ("acos", &acos);
        define ("asin", &asin);
        define ("atan", &atan);
        define ("atan2", &atan2);
        define ("cos", &mcos);
        define ("sin", &msin);
        define ("tan", &mtan);
        define ("cosh", &cosh);
        define ("sinh", &sinh);
        define ("tanh", &tanh);
        define ("exp", &std.math.exp);
        define ("log", &log);
        define ("log10", &log10);
        define ("pow", &pow);
        define ("sqrt", &msqrt);
        define ("ceil", &ceil);
        define ("floor", &floor);
        define ("log1p", &log1p);
        define ("expm1", &expm1);
        define ("hypot", &mhypot);
    }

    /** Evaluate text and set the result in value.  Returns false if the
      * text could not be parsed, or if anything other than whitespace
      * follows the expression; value may still hold a partial result in
      * that case.
      */

    bit eval (char [] text, out float value)
    {
        char *s = text;
        char *e = s + text.length;

        if (!exp (value, s, e))
            return false;
        skipSpaces (s, e);
        return (s >= e);
    }

    /** Advance s past any whitespace, stopping at e. */
    void skipSpaces (inout char *s, char *e)
    {
        while (s < e && isspace (*s))
            s ++;
    }
    
/+
#ifdef DoxygenMustSkipThis
+/

    /* extern (C) wrappers with the FuncED/FuncEDD signatures around a few
     * std.math functions, used by predefineMath (presumably because those
     * std.math functions cannot be passed to define directly).
     */
    static extern (C) real mcos (double v) { return std.math.cos (v); }
    static extern (C) real mtan (double v) { return tan (v); }
    static extern (C) real msin (double v) { return std.math.sin (v); }
    static extern (C) real msqrt (double v) { return std.math.sqrt (v); }
    static extern (C) real mhypot (double x, double y) { return hypot (x, y); }

    /* conditionalExp
     * conditionalExp , exp
     *
     * The value of a comma expression is that of its last operand.
     */

    bit exp (inout float value, inout char *s, char *e)
    {
        if (!conditionalExp (value, s, e))
            return false;

        while (1)
        {
            skipSpaces (s, e);
            if (s >= e)
                return true;
            if (*s == ',')
            {
                s ++;
                if (!conditionalExp (value, s, e))
                    return false;
            }
            else
                return true;
        }
    }

    /* orOrExp
     * orOrExp ? exp : conditionalExp
     *
     * Right-associative, so "a ? b : c ? d : f" means "a ? b : (c ? d : f)".
     */

    bit conditionalExp (inout float value, inout char *s, char *e)
    {
        float a, b;

        if (!orOrExp (value, s, e))
            return false;
        skipSpaces (s, e);
        if (s >= e || *s != '?')
            return true;

        s ++;
        if (!exp (a, s, e))
            return false;
        if (s >= e || *s != ':')
            return false;
        s ++;

        // Recurse (rather than loop) so that a chained ?: on the right
        // binds to this condition's "false" branch.
        if (!conditionalExp (b, s, e))
            return false;
        value = value ? a : b;
        return true;
    }

    /* andAndExp
     * andAndExp || orOrExp
     */

    bit orOrExp (inout float value, inout char *s, char *e)
    {
        float sub;

        if (!andAndExp (value, s, e))
            return false;
        while (1)
        {
            skipSpaces (s, e);
            if (s >= e - 1)
                return true;
            if (*s == '|' && s [1] == '|')
            {
                s += 2;
                if (!andAndExp (sub, s, e))
                    return false;
                value = value || sub;
            }
            else
                return true;
        }
    }

    /* orExp
     * orExp && andAndExp
     */

    bit andAndExp (inout float value, inout char *s, char *e)
    {
        float sub;

        if (!orExp (value, s, e))
            return false;
        while (1)
        {
            skipSpaces (s, e);
            if (s >= e - 1)
                return true;
            if (*s == '&' && s [1] == '&')
            {
                s += 2;
                if (!orExp (sub, s, e))
                    return false;
                value = value && sub;
            }
            else
                return true;
        }
    }

    /* xorExp
     * xorExp | orExp
     *
     * Operands are truncated to ulong before combining.
     */

    bit orExp (inout float value, inout char *s, char *e)
    {
        float sub;

        if (!xorExp (value, s, e))
            return false;
        while (1)
        {
            skipSpaces (s, e);
            if (s >= e)
                return true;
            // A single '|' only; "||" belongs to orOrExp.
            if (*s == '|' && (s >= e - 1 || s [1] != '|'))
            {
                s ++;
                if (!xorExp (sub, s, e))
                    return false;
                value = (ulong) value | (ulong) sub;
            }
            else
                return true;
        }
    }

    /* andExp
     * andExp ^ xorExp
     *
     * Operands are truncated to ulong before combining.
     */

    bit xorExp (inout float value, inout char *s, char *e)
    {
        float sub;

        if (!andExp (value, s, e))
            return false;
        while (1)
        {
            skipSpaces (s, e);
            if (s >= e)
                return true;
            if (*s == '^')
            {
                s ++;
                if (!andExp (sub, s, e))
                    return false;
                value = (ulong) value ^ (ulong) sub;
            }
            else
                return true;
        }
    }

    /* equalExp
     * equalExp & andExp
     *
     * Operands are truncated to ulong before combining.
     */

    bit andExp (inout float value, inout char *s, char *e)
    {
        float sub;

        if (!equalExp (value, s, e))
            return false;
        while (1)
        {
            skipSpaces (s, e);
            if (s >= e)
                return true;
            // A single '&' only; "&&" belongs to andAndExp.
            if (*s == '&' && (s >= e - 1 || s [1] != '&'))
            {
                s ++;
                if (!equalExp (sub, s, e))
                    return false;
                value = (ulong) value & (ulong) sub;
            }
            else
                return true;
        }
    }

    /* relExp
     * relExp == equalExp
     * relExp != equalExp
     */

    bit equalExp (inout float value, inout char *s, char *e)
    {
        float sub;

        if (!relExp (value, s, e))
            return false;
        while (1)
        {
            skipSpaces (s, e);
            if (s >= e - 1)
                return true;
            if (*s == '=' && s [1] == '=')
            {
                s += 2;
                if (!relExp (sub, s, e))
                    return false;
                value = (value == sub);
            }
            else if (*s == '!' && s [1] == '=')
            {
                s += 2;
                if (!relExp (sub, s, e))
                    return false;
                value = (value != sub);
            }
            else
                return true;
        }
    }

    /* shiftExp
     *
     * Relational operators (< <= > >=) are not implemented; this level
     * simply forwards to shiftExp.
     */

    bit relExp (inout float value, inout char *s, char *e)
    {
        return shiftExp (value, s, e);
    }

    /* addExp
     * addExp << shiftExp
     * addExp >> shiftExp
     *
     * Shifts are computed as multiplication/division by 2 ** sub, so the
     * result is not truncated to an integer (e.g. "9 >> 1" gives 4.5).
     */

    bit shiftExp (inout float value, inout char *s, char *e)
    {
        float sub;

        if (!addExp (value, s, e))
            return false;
        while (1)
        {
            skipSpaces (s, e);
            if (s >= e - 1)
                return true;
            if (s [0] == '<' && s [1] == '<')
            {
                s += 2;
                if (!addExp (sub, s, e))
                    return false;
                value = value * pow(cast (double) 2, cast (double) sub);
            }
            else if (s [0] == '>' && s [1] == '>')
            {
                s += 2;
                if (!addExp (sub, s, e))
                    return false;
                value = value / pow(cast (double) 2, cast (double) sub);
            }
            else
                return true;
        }
    }

    /* mulExp
     * mulExp + addExp
     * mulExp - addExp
     */

    bit addExp (inout float value, inout char *s, char *e)
    {
        float sub;

        if (!mulExp (value, s, e))
            return false;
        while (1)
        {
            skipSpaces (s, e);
            if (s >= e)
                return true;

            if (*s == '+')
            {
                s ++;
                if (!mulExp (sub, s, e))
                    return false;
                value += sub;
            }
            else if (*s == '-')
            {
                s ++;
                if (!mulExp (sub, s, e))
                    return false;
                value -= sub;
            }
            else
                return true;
        }
    }

    /* powExp
     * powExp * mulExp
     * powExp / mulExp
     * powExp % mulExp
     */

    bit mulExp (inout float value, inout char *s, char *e)
    {
        float sub;

        if (!powExp (value, s, e))
            return false;

        while (1)
        {
            skipSpaces (s, e);
            if (s >= e)
                return true;

            // A single '*' only; "**" is exponentiation, handled by powExp.
            if (*s == '*' && (s == e - 1 || s [1] != '*'))
            {
                s ++;
                if (!powExp (sub, s, e))
                    return false;
                value *= sub;
            }
            else if (*s == '/')
            {
                s ++;
                if (!powExp (sub, s, e))
                    return false;
                value /= sub;
            }
            else if (*s == '%')
            {
                s ++;
                if (!powExp (sub, s, e))
                    return false;
                value %= sub;
            }
            else
                return true;
        }
    }

    /* unaryExp
     * unaryExp ** powExp
     *
     * Right-associative, so "2 ** 3 ** 2" means "2 ** (3 ** 2)".
     */

    bit powExp (inout float value, inout char *s, char *e)
    {
        float sub;

        if (!unaryExp (value, s, e))
            return false;

        skipSpaces (s, e);
        // Both characters of "**" must lie before e.
        if (s < e - 1 && s [0] == '*' && s [1] == '*')
        {
            s += 2;
            if (!powExp (sub, s, e))
                return false;
            value = pow(cast (double) value, cast (double) sub);
        }
        return true;
    }

    /* numberExp
     * - unaryExp
     * + unaryExp
     * ~ unaryExp
     * ! unaryExp
     * ( conditionalExp )
     * identifierExp
     */

    bit unaryExp (inout float value, inout char *s, char *e)
    {
        skipSpaces (s, e);

        if (s >= e)
            return false;

        if (isdigit (*s) || *s == '.')
            return numberExp (value, s, e);

        if (*s == '-')
        {
            s ++;
            if (!unaryExp (value, s, e))
                return false;
            value = -value;
            return true;
        }

        if (*s == '+')
        {
            s ++;
            if (!unaryExp (value, s, e))
                return false;
            value = +value;
            return true;
        }

        if (*s == '~')
        {
            s ++;
            if (!unaryExp (value, s, e))
                return false;
            value = ~(long) value;
            return true;
        }

        if (*s == '!')
        {
            s ++;
            if (!unaryExp (value, s, e))
                return false;
            value = !value;
            return true;
        }

        if (*s == '(')
        {
            s ++;
            if (!conditionalExp (value, s, e))
                return false;
            if (s >= e || *s != ')')
                return false;
            s ++;
            return true;
        }

        if (isalpha (*s) || *s == '_')
            return identifierExp (value, s, e);

        return false;
    }

    /** identifier
      * identifier ( conditionalExp )
      * identifier ( conditionalExp , conditionalExp )
      *
      * The identifier is lower-cased and looked up first among constants,
      * then one-argument functions, then two-argument functions.  Fails if
      * it is not defined, or if a function name is not followed by a
      * parenthesised argument list of the right arity.
      */

    bit identifierExp (inout float value, inout char *s, char *e)
    {
        char *t = s;
        char [] id;
        float a, b;

        for (s ++; s < e && (isalnum (*s) || *s == '_'); s ++)
            { }

        id = std.string.tolower (t [0 .. (int) (s - t)].dup);
        if (id in digCommonConstants)
        {
            value = digCommonConstants [id];
            return true;
        }

        if (id in digCommonListED)
        {
            skipSpaces (s, e);
            if (s >= e || *s != '(')
                return false;
            s ++;

            if (!conditionalExp (a, s, e))
                return false;

            if (s >= e || *s != ')')
                return false;
            s ++;

            value = (double) digCommonListED [id] (a);
            return true;
        }

        if (id in digCommonListEDD)
        {
            skipSpaces (s, e);
            if (s >= e || *s != '(')
                return false;
            s ++;

            if (!conditionalExp (a, s, e))
                return false;
            if (s >= e || *s != ',')
                return false;
            s ++;

            if (!conditionalExp (b, s, e))
                return false;
            if (s >= e || *s != ')')
                return false;
            s ++;

            value = (double) digCommonListEDD [id] (a, b);
            return true;
        }

        return false;
    }

    /* Parse a numeric literal starting at s.  Recognised forms:
     *   0x... / 0X...              hexadecimal
     *   0b... / 0B...              binary
     *   ...h / ...H                hexadecimal with a trailing h suffix
     *   0 then an octal digit      octal
     *   anything else              decimal
     * Each form may have a fractional part after '.'.  Hex forms take an
     * optional exponent "p[+-]digits" (a power of two); the other forms
     * take "e[+-]digits" (a power of ten).  Exponent digits are always
     * decimal.  Fails if no mantissa digits are present (e.g. "0x" or ".").
     */

    bit numberExp (inout float value, inout char *s, char *e)
    {
        char [] digits = "0123456789abcdef";
        int base = 10;
        int d, mantissaDigits = 0;
        bit hexSuffix = false;

        real number = 0, num = 0, den = 1, exponent = 0, expSign = +1, expBase = 10;
        char expDigit = 'e';

        if (s >= e)
            return false;

        if (s < e - 1 && *s == '0' && (s [1] == 'x' || s [1] == 'X'))
        {
            base = 16;
            s += 2;
        }
        else if (s < e - 1 && *s == '0' && (s [1] == 'b' || s [1] == 'B'))
        {
            base = 2;
            s += 2;
        }
        else if (isdigit (*s))
        {
            char *t = s;

            // Look for an 'h' suffix directly after this literal's own
            // characters only, not anywhere in the rest of the input.
            while (t < e && (isxdigit (*t) || *t == '.'))
                t ++;

            if (t < e && (*t == 'h' || *t == 'H'))
            {
                base = 16;
                hexSuffix = true;
            }
            else if (*s == '0' && s < e - 1 && s [1] >= '0' && s [1] <= '7')
            {
                // Octal only when the leading 0 is followed by an octal
                // digit, so "0", "0.5" and "0e3" stay decimal.
                s ++;
                base = 8;
            }
        }

        if (base == 16)
        {
            expDigit = 'p';
            expBase = 2;
        }
        digits = digits [0 .. base];

        while (s < e && (d = find (digits, std.ctype.tolower (*s))) != -1)
        {
            number = number * base + d;
            mantissaDigits ++;
            s ++;
        }

        if (s < e && *s == '.')
        {
            s ++;
            while (s < e && (d = find (digits, std.ctype.tolower (*s))) != -1)
            {
                num = num * base + d;
                den = den * base;
                mantissaDigits ++;
                s ++;
            }
        }

        if (mantissaDigits == 0)
            return false;

        if (s < e && std.ctype.tolower (*s) == expDigit)
        {
            s ++;
            if (s < e && *s == '+')
                s ++;
            else if (s < e && *s == '-')
            {
                s ++;
                expSign = -1;
            }

            // The exponent is decimal whatever the mantissa's base.
            while (s < e && isdigit (*s))
            {
                exponent = exponent * 10 + (*s - '0');
                s ++;
            }
        }

        if (hexSuffix && s < e && (*s == 'h' || *s == 'H'))
            s ++;

        value = (number + num / den) * pow(cast (double) expBase, cast (double) (exponent * expSign));
        return true;
    }

    /* Registered constants, keyed by lower-case name. */
    float [char []] digCommonConstants;

    /* Registered one- and two-argument functions, keyed by lower-case name. */
    FuncED [char []] digCommonListED;
    FuncEDD [char []] digCommonListEDD;
    
    /* Unit-test helper: throw an Error unless evaluating text succeeds (or
     * fails) as expected and the result is within 1% of value.  When value
     * is NaN (as passed by digCommonTestFail) the comparison never throws.
     */
    void digCommonTest (char [] text, bit success, float value)
    {
        float v;
        
        if (eval (text, v) != success)
            throw new Error ("'" ~ text ~ "' unexpectedly " ~ (success ? (char []) "failed" : (char []) "succeeded"));
        if (fabs (v - value) > std.math.fabs (value) / 100)
            throw new Error (fmt ("'%.*s' should have evaluated to %f but resulted in %f", text, value, v));
    }

    /* Unit-test helper: text must evaluate to (approximately) value. */
    void digCommonTest (char [] text, float value)
    {
        digCommonTest (text, true, value);
    }

    /* Unit-test helper: text must fail to evaluate. */
    void digCommonTestFail (char [] text)
    {
        digCommonTest (text, false, float.nan);
    }

    unittest
    {
        
        with (new ExpressionEvaluator ())
        {
            digCommonTest ("654.32", 654.32);
            digCommonTest ("5.284e3", 5.284e3);
            digCommonTest ("0x56.34", 0x56.34p0);
            digCommonTest ("0x32p1", 0x32p1);
            digCommonTest ("0b100101", 0b100101);
            digCommonTest ("1 + 400", 401);
            digCommonTest ("8 * -16", 8 * -16);
            digCommonTest ("2 + 3 * 4", 2 + 3 * 4);
            digCommonTest ("2 * 3 + 4", 2 * 3 + 4);
            digCommonTest ("2 * (3 + 4)", 2 * (3 + 4));
            digCommonTest ("43 & 83", 43 & 83);
            digCommonTest ("98 | 37", 98 | 37);
            digCommonTest ("91 ^ 26", 91 ^ 26);
            digCommonTest ("5 << 2", 5 << 2);
            digCommonTest ("8 >> 1", 8 >> 1);
            digCommonTest ("8 ** 3", 8 * 8 * 8);

            // Regression cases.
            digCommonTest ("0.5", 0.5);
            digCommonTest ("017", 15);
            digCommonTest ("1Fh", 31);
            digCommonTest ("0x1p10", 0x1p10);
            digCommonTest ("2 ** 3 ** 2", 512);
            digCommonTest ("1 ? 2 : 0 ? 3 : 4", 2);
            define ("Half", 0.5f);
            digCommonTest ("1e1 + half", 10.5);
            digCommonTestFail ("2 *");
            digCommonTestFail ("0x");
        }
    }
    
/+
#endif
+/
}

